{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77be7dda",
   "metadata": {},
   "source": [
    "## Deploying Yi-34B-Chat-4bits with vLLM rolling batch on LMI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d69a078d",
   "metadata": {},
   "source": [
    "In this tutorial, you will use the vLLM backend of the Large Model Inference (LMI) Deep Learning Container (DLC) to deploy Yi-34B-Chat-4bits and run inference against it.\n",
    "\n",
    "Make sure the following permissions are granted before running the notebook:\n",
    "\n",
    "* S3 bucket write access\n",
    "* SageMaker access\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bc171a0",
   "metadata": {},
   "source": [
    "### Step 1: Upgrade the SageMaker SDK and import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdd49db4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "271ce0b4-e9b2-4661-9767-7580ff6f86be",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install transformers sentencepiece --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc3c2465",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import Model, image_uris, serializers, deserializers\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "region = sess.boto_region_name  # AWS region of the current SageMaker session\n",
    "account_id = sess.account_id()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ab26348",
   "metadata": {},
   "source": [
    "### Step 2: Start preparing model artifacts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86df559e",
   "metadata": {},
   "source": [
    "The LMI container expects the following artifacts to set up the model. This notebook only provides `serving.properties`:\n",
    "\n",
    "* serving.properties (required): defines the model server settings\n",
    "* model.py (optional): a Python file that defines the core inference logic\n",
    "* requirements.txt (optional): any additional pip packages to install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19ff0dc9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%writefile serving.properties\n",
    "engine=Python\n",
    "option.model_id=01-ai/Yi-34B-Chat-4bits\n",
    "option.tensor_parallel_degree=4\n",
    "option.max_rolling_batch_size=64\n",
    "option.rolling_batch=vllm\n",
    "option.quantize=awq\n",
    "option.dtype=fp16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c21521",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "mkdir mymodel\n",
    "mv serving.properties mymodel/\n",
    "tar czvf mymodel.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
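  {
   "cell_type": "markdown",
   "id": "3c1e7a90",
   "metadata": {},
   "source": [
    "As an optional sanity check before uploading, the sketch below lists the tarball members with Python's standard `tarfile` module. It assumes the `mymodel.tar.gz` file created in the previous cell sits in the current working directory and should contain a `mymodel/` directory holding `serving.properties`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b82d4f16",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import tarfile\n",
    "\n",
    "# List the archive members to confirm serving.properties was packaged\n",
    "with tarfile.open(\"mymodel.tar.gz\", \"r:gz\") as tar:\n",
    "    members = tar.getnames()\n",
    "\n",
    "print(members)\n",
    "assert \"mymodel/serving.properties\" in members, \"serving.properties is missing from the archive\""
   ]
  },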
  {
   "cell_type": "markdown",
   "id": "5a817d0e",
   "metadata": {},
   "source": [
    "### Step 3: Start building SageMaker endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42b20125",
   "metadata": {},
   "source": [
    "#### Getting the container image URI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e61a908b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "image_uri = image_uris.retrieve(\n",
    "    framework=\"djl-deepspeed\",\n",
    "    region=sess.boto_session.region_name,\n",
    "    version=\"0.27.0\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96a08056",
   "metadata": {},
   "source": [
    "#### Upload artifact on S3 and create SageMaker model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f903af58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_name = \"01-ai/Yi-34B-Chat-4bits\"\n",
    "s3_code_prefix = f\"large-model-vllm/{model_name}/code\"  # S3 key prefix for the uploaded tarball\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "code_artifact = sess.upload_data(\"mymodel.tar.gz\", bucket, s3_code_prefix)\n",
    "print(f\"S3 Code or Model tar ball uploaded to --- > {code_artifact}\")\n",
    "\n",
    "model = Model(image_uri=image_uri, model_data=code_artifact, role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c04c778f",
   "metadata": {},
   "source": [
    "#### Create SageMaker endpoint with a specified instance type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f27a8df",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "instance_type = \"ml.g4dn.12xlarge\"  # 4 GPUs, matching tensor_parallel_degree=4\n",
    "endpoint_name = sagemaker.utils.name_from_base(f\"lmi-model-{model_name.replace('/', '-')}\")\n",
    "print(f\"endpoint_name: {endpoint_name}\")\n",
    "\n",
    "model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=instance_type,\n",
    "    endpoint_name=endpoint_name,\n",
    "    container_startup_health_check_timeout=1800,  # seconds; leaves time to download and load the model\n",
    ")\n",
    "\n",
    "# requests are sent as JSON, so we set a JSON serializer; responses come back as raw bytes\n",
    "predictor = sagemaker.Predictor(\n",
    "    endpoint_name=endpoint_name,\n",
    "    sagemaker_session=sess,\n",
    "    serializer=serializers.JSONSerializer(),\n",
    ")"
   ]
  },
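  {
   "cell_type": "markdown",
   "id": "a3d91c47",
   "metadata": {},
   "source": [
    "`model.deploy` waits for the endpoint by default, so the following check is optional. As a minimal sketch, it reads the endpoint status through the boto3 SageMaker client's `describe_endpoint` call; `sm_client` is a name introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0b2f88",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sm_client = boto3.client(\"sagemaker\", region_name=region)\n",
    "\n",
    "# EndpointStatus reads \"InService\" once the endpoint is ready to serve requests\n",
    "status = sm_client.describe_endpoint(EndpointName=endpoint_name)[\"EndpointStatus\"]\n",
    "print(f\"{endpoint_name}: {status}\")"
   ]
  },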
  {
   "cell_type": "markdown",
   "id": "19004557",
   "metadata": {},
   "source": [
    "### Step 4: Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f35d071-2c34-4919-baf8-bda7fef9d49c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "MODEL_DIR = model_name\n",
    "# only the tokenizer is needed locally, to render the chat template; the model itself runs on the endpoint\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=False)\n",
    "\n"
   ]
  },
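  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd192a22-3cda-4032-bdfd-0491e88cc068",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Multi-turn example (uncomment to send a follow-up question):\n",
    "# messages = [\n",
    "#     {\"role\": \"user\", \"content\": \"Which is the second-highest mountain in the world?\"},\n",
    "#     {\"role\": \"assistant\", \"content\": \"The second-highest mountain in the world is China's Mount Everest (Qomolangma). Mount Everest lies in the Himalayas, on the border between the southern Tibet Autonomous Region of China and Nepal, at an elevation of 8,848.86 meters (29,029 feet). It is the highest mountain range in the world and also the highest peak in mainland China.\\\\n\\\\nThe name Qomolangma comes from Tibetan: 'Qomo' means goddess and 'langma' means mother of the world, so the name as a whole means 'Mother of the Earth'. The peak is not only renowned in the mountaineering world, but climbing enthusiasts and explorers worldwide also\"},\n",
    "#     {\"role\": \"user\", \"content\": \"Summarize in one sentence\"}\n",
    "# ]\n",
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": \"Which is the second-highest mountain in the world?\"},\n",
    "]\n",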
    "\n",
    "\n",
    "input_text = tokenizer.apply_chat_template(conversation=messages, tokenize=False, add_generation_prompt=True)\n",
    "print(json.dumps(input_text, ensure_ascii=False))"
   ]
  },
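  {
   "cell_type": "markdown",
   "id": "c7f04b29",
   "metadata": {},
   "source": [
    "To make the effect of the chat template concrete, the sketch below rebuilds the prompt by hand. It assumes the Yi chat template follows the ChatML layout (`<|im_start|>role`, the content, then `<|im_end|>`), which is consistent with the `<|im_end|>` stop string used in the streaming section. `build_chatml_prompt` is a hypothetical helper; `apply_chat_template` remains the source of truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d6a13e5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def build_chatml_prompt(messages, add_generation_prompt=True):\n",
    "    # Assumed ChatML layout: each turn is wrapped in <|im_start|> ... <|im_end|>\n",
    "    prompt = \"\".join(\n",
    "        f\"<|im_start|>{m['role']}\\n{m['content']}<|im_end|>\\n\" for m in messages\n",
    "    )\n",
    "    if add_generation_prompt:\n",
    "        # Open an assistant turn so the model continues from here\n",
    "        prompt += \"<|im_start|>assistant\\n\"\n",
    "    return prompt\n",
    "\n",
    "\n",
    "manual_prompt = build_chatml_prompt(messages)\n",
    "# True only if the tokenizer's template matches the assumed layout\n",
    "print(manual_prompt == input_text)"
   ]
  },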
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccbf0046",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "parameters = {\n",
    "    \"max_new_tokens\": 128,\n",
    "    \"do_sample\": True,\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.8,\n",
    "    \"stop\": \"<|im_end|>\",  # end-of-turn marker, as in the streaming example\n",
    "}\n",
    "# send the sampling parameters defined above along with the prompt\n",
    "response = predictor.predict({\"inputs\": input_text, \"parameters\": parameters})\n",
    "\n",
    "text = str(response, 'utf-8')\n",
    "print(text)"
   ]
  },
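  {
   "cell_type": "markdown",
   "id": "e41b7d03",
   "metadata": {},
   "source": [
    "The endpoint returns raw bytes, decoded into `text` above. As a sketch, and assuming the body is a JSON object with a `generated_text` field (a response shape this notebook does not confirm), the generated text can be extracted as follows. The code falls back to printing the raw string if that assumption does not hold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f8c2a6b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "try:\n",
    "    body = json.loads(text)\n",
    "except json.JSONDecodeError:\n",
    "    body = None  # not valid JSON; fall back to the raw text below\n",
    "\n",
    "# Assumed response shape: {\"generated_text\": \"...\"}\n",
    "if isinstance(body, dict) and \"generated_text\" in body:\n",
    "    print(body[\"generated_text\"])\n",
    "else:\n",
    "    print(text)"
   ]
  },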
  {
   "cell_type": "markdown",
   "id": "f04cbcd1-3841-44da-b9fd-4b96e4c72783",
   "metadata": {},
   "source": [
    "## Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd17ba98-ffd1-49ed-95d9-10840c2fdae4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import boto3\n",
    "from utils.LineIterator import LineIterator\n",
    "\n",
    "smr_client = boto3.client(\"sagemaker-runtime\")\n",
    "def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):\n",
    "    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(\n",
    "        EndpointName=endpoint_name,\n",
    "        Body=json.dumps(payload), \n",
    "        ContentType=\"application/json\",\n",
    "        CustomAttributes='accept_eula=false'\n",
    "    )\n",
    "    return response_stream\n",
    "\n",
    "\n",
    "\n",
    "def print_response_stream(response_stream):\n",
    "    event_stream = response_stream.get('Body')\n",
    "    for line in LineIterator(event_stream):\n",
    "        print(line, end='')"
   ]
  },
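  {
   "cell_type": "markdown",
   "id": "72b5e9c0",
   "metadata": {},
   "source": [
    "`LineIterator` is imported from a local `utils` package that is not shown in this notebook. If that helper is unavailable, a minimal stand-in could look like the sketch below. It assumes each stream event carries its data under `event[\"PayloadPart\"][\"Bytes\"]`, as returned by `invoke_endpoint_with_response_stream`. `iter_stream_lines` is a hypothetical name, and the last lines exercise it on a hand-made fake stream rather than a live endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad3f5c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def iter_stream_lines(event_stream):\n",
    "    \"\"\"Yield decoded text lines from a SageMaker response event stream.\"\"\"\n",
    "    buffer = b\"\"\n",
    "    for event in event_stream:\n",
    "        part = event.get(\"PayloadPart\")\n",
    "        if part is None:\n",
    "            continue  # skip events that carry no payload bytes\n",
    "        buffer += part[\"Bytes\"]\n",
    "        # Emit every complete line; keep the unfinished tail in the buffer\n",
    "        while b\"\\n\" in buffer:\n",
    "            line, buffer = buffer.split(b\"\\n\", 1)\n",
    "            yield line.decode(\"utf-8\")\n",
    "    if buffer:\n",
    "        yield buffer.decode(\"utf-8\")  # trailing data without a final newline\n",
    "\n",
    "\n",
    "# Exercise the helper on a hand-made fake stream (no endpoint needed)\n",
    "fake_stream = [\n",
    "    {\"PayloadPart\": {\"Bytes\": b'{\"a\": 1}\\n{\"b\"'}},\n",
    "    {\"PayloadPart\": {\"Bytes\": b': 2}\\n'}},\n",
    "]\n",
    "print(list(iter_stream_lines(fake_stream)))"
   ]
  },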
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69410da3-91e0-459e-886c-61c282700507",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "parameters = {\n",
    "    \"max_new_tokens\": 1024,\n",
    "    \"do_sample\": True,\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.8,\n",
    "    \"repetition_penalty\": 1.2,\n",
    "    \"stop\": \"<|im_end|>\",\n",
    "}\n",
    "payload = {\n",
    "    \"inputs\": input_text,\n",
    "    \"parameters\": parameters,\n",
    "    \"stream\": True,  # request a streamed response\n",
    "}\n",
    "response_stream = get_realtime_response_stream(smr_client, endpoint_name, payload)\n",
    "print_response_stream(response_stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd8073fb-0880-4697-aedf-4711918cacbd",
   "metadata": {},
   "source": [
    "## Clean up resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9603928-b375-4d87-8c29-63b3cd59306a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sess.delete_endpoint(endpoint_name)\n",
    "sess.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p310",
   "language": "python",
   "name": "conda_pytorch_p310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
